import torch
import torch.nn as nn
import torch.nn.functional as F




class Select_object(nn.Module):
    """Memory bank that retrieves stored object crops by text/image-feature similarity.

    A fixed-size ring buffer holds L2-normalized image features (``queue``)
    alongside the object crops they were computed from (``queue_object``).
    For every text feature in a batch, :meth:`forward` returns the
    ``select_num`` stored crops whose features have the highest (optionally
    scaled) dot product with that text feature. It then enqueues the current
    batch's crops so that later batches can retrieve them.

    Args:
        queue_size (int): Number of feature/crop pairs kept in memory. (default: 256)
        momentum (float): Stored as ``self.momentum``; not used by this class.
        temperature (float): Stored as ``self.temperature``; not used by this class.
        hidden_size (int): Width of the stored image features. (default: 512)
        select_num (int): Number of crops retrieved per text feature; must not
            exceed ``queue_size``. (default: 5)
        clip_image_encoder (callable): Maps a batch of crops to a batch of image
            features of width ``hidden_size``. Required by :meth:`forward`.
        logit_scale (float | Tensor | None): Multiplier applied to the similarity
            scores before ranking; ``None`` means no scaling.
    """

    def __init__(self, queue_size=256, momentum=0.999, temperature=0.07, hidden_size=512, select_num=5, clip_image_encoder=None, logit_scale=None):
        super().__init__()

        self.queue_size = queue_size
        self.momentum = momentum
        self.temperature = temperature

        self.clip_image_encoder = clip_image_encoder
        self.logit_scale = logit_scale
        self.select_num = select_num

        # Placeholder features are L2-normalized so they are on the same scale as
        # the normalized image features enqueued later. Raw randn rows have norm
        # ~sqrt(hidden_size) and would tend to dominate the dot-product ranking.
        self.register_buffer("queue", F.normalize(torch.randn(self.queue_size, hidden_size), dim=1))
        # Placeholder crops. Until every slot has been written once, forward()
        # may still return some of these random images.
        self.register_buffer("queue_object", torch.randn(self.queue_size, 3, 224, 224))
        # Index of the next slot to overwrite (the oldest entry once the queue is full).
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def update_queue(self, feat_k, maskrcnn_object_lists):
        """Write a batch into the ring buffer, overwriting the oldest entries.

        The batch is written starting at ``queue_ptr`` and wraps around to the
        front of the queue when it reaches the end. Every slot is therefore
        reused in FIFO order, whether or not ``queue_size`` is a multiple of
        the batch size. If the batch is larger than the queue, only its last
        ``queue_size`` rows are kept.

        Args:
            feat_k (Tensor): Image features, shape (n, hidden_size).
            maskrcnn_object_lists (Tensor): Crops paired row-for-row with
                ``feat_k``; each row must fit a ``queue_object`` slot (3, 224, 224).
        """
        if feat_k.size(0) > self.queue_size:
            # Only the most recent queue_size rows can survive anyway.
            feat_k = feat_k[-self.queue_size:]
            maskrcnn_object_lists = maskrcnn_object_lists[-self.queue_size:]
        batch_size = feat_k.size(0)
        ptr = int(self.queue_ptr)

        # Fill from ptr up to the end of the queue...
        head = min(batch_size, self.queue_size - ptr)
        self.queue[ptr:ptr + head] = feat_k[:head]
        self.queue_object[ptr:ptr + head] = maskrcnn_object_lists[:head]
        # ...then wrap whatever is left around to the front.
        tail = batch_size - head
        if tail > 0:
            self.queue[:tail] = feat_k[head:]
            self.queue_object[:tail] = maskrcnn_object_lists[head:]

        self.queue_ptr[0] = (ptr + batch_size) % self.queue_size

    def forward(self, text_features, mask_rcnn_object_lists):
        """Retrieve the top-``select_num`` stored crops for each text feature.

        Args:
            text_features (Tensor): Shape (B, hidden_size). Presumably
                L2-normalized like the stored image features (TODO: confirm
                with the caller).
            mask_rcnn_object_lists (Tensor): Crops for the current batch. They
                are encoded with ``clip_image_encoder`` and enqueued *after*
                retrieval, so they cannot be selected in this same call.

        Returns:
            Tensor: Shape (B, select_num, 3, 224, 224), with each row's crops
            ordered from highest to lowest score.
        """
        with torch.no_grad():
            object_image_features = self.clip_image_encoder(mask_rcnn_object_lists)
            object_image_features = object_image_features / object_image_features.norm(dim=1, keepdim=True)

            # (B, queue_size) similarity of every text feature to every stored feature.
            scores = text_features.to(self.queue.dtype) @ self.queue.t()
            if self.logit_scale is not None:
                scores = self.logit_scale * scores
            _, indices = scores.topk(self.select_num, dim=1, largest=True, sorted=True)
            # Indexing with (B, select_num) indices gives (B, select_num, 3, 224, 224).
            select_objects = self.queue_object[indices]

        self.update_queue(feat_k=object_image_features, maskrcnn_object_lists=mask_rcnn_object_lists)

        return select_objects






